; If using Jupyter Lab, run
; jupyter labextension install @jupyterlab/javascript-extension
; before using this notebook.
(ns tutorial
(:refer-clojure :only [])
(:require
[clojure.repl :refer :all]
[metaprob.state :as state]
[metaprob.trace :as trace]
[metaprob.sequence :as sequence]
[metaprob.builtin-impl :as impl]
[metaprob.syntax :refer :all]
[metaprob.builtin :refer :all]
[metaprob.prelude :refer :all]
[metaprob.distributions :refer :all]
[metaprob.interpreters :refer :all]
[metaprob.inference :refer :all]
[metaprob.compositional :as comp]
[metaprob.examples.gaussian :refer :all]
[metaprob.tutorial.jupyter :refer :all]))
(enable-inline-viz)
Metaprob is a new, experimental language for probabilistic programming, embedded in Clojure. In this tutorial, we'll take a bottom-up approach to introducing the language. We'll begin by exploring Metaprob's core data types, including the flexible "trace" data structure at the heart of the language's design. From there, we'll look at the probabilistic semantics of Metaprob programs, and tackle some small statistical inference problems. By the end of this tutorial, we'll be querying a more sophisticated model of real US census data!
The Metaprob language is embedded in Clojure, and inherits several of Clojure's atomic datatypes: Clojure's numbers, keywords, booleans, and strings are all Metaprob values, too. Many of Clojure's functions for manipulating these values have Metaprob versions.
(+ 5 7)
If you want to use a Clojure function that is not available in Metaprob, you can fully qualify the procedure's name.
(clojure.core/str "Hello" " " "world")
(Any Clojure function can be used without modification in Metaprob. For reasons we'll cover a bit later, however, impure Clojure functions with nondeterministic behavior (e.g., those that generate random values during their execution) must be called with care, so as not to break Metaprob's probabilistic semantics. Note that this restriction also applies to higher-order Clojure functions supplied with stochastic procedures as arguments.)
At first glance, it appears that Metaprob supports Clojure's vector and list datatypes, too. Here, we apply Metaprob's length function to a vector, then to a list.
(length [1 5 2])
(length (list 1 2 3))
But really, lists and vectors are just special cases of Metaprob's "trace" data structure:
(and (trace? [1 2 3]) (trace? '(1 2 3)))
(plot-trace [5 2 9] 200 200) ; produce a 200x200 diagram of the given trace
(plot-trace '(5 2 9) 500 100) ; produce a 500x100 diagram of the given trace
A trace is a tree-like data structure. Every trace:
- may store a single value at its root node, and
- may have zero or more subtraces, each of which is itself a trace.
Subtraces may be named with strings, numbers, or booleans. The name of each subtrace appears above its root node in the diagrams above.
So far, you have seen two special kinds of trace:
A vector is a trace with no value at its root, and n subtraces, labeled by the integers 0 through n-1. In a vector, each subtrace stores a value at its root, and has no children.
A list is either:
- the empty list, or
- a trace that stores the first element of the list at its root, and the rest of the list in a subtrace named "rest".
But most traces are not vectors or lists. Consider the following trace, for example:
(plot-trace (gen [x y] (+ x (* 3 y))))
As you may have guessed, the above value is a Metaprob procedure. (The gen macro, which we'll cover in more detail soon, is Metaprob's version of Clojure's fn, Scheme's lambda, or JavaScript's anonymous function.) In Metaprob, all compound data, including procedures, are stored as traces.
Let's look at some of the tools Metaprob provides for manipulating and creating traces.
We'll start with trace manipulation. We'll use the trace depicted above as a running example. Let's give it a name using Metaprob's define form. Below, we also demonstrate the pprint function, which is another way (apart from plot-trace) to get a human-readable representation of a trace. You can ignore the word COMPILED for now: it just indicates that Metaprob has produced an efficient Clojure-compiled version of the procedure that this trace represents (something the gen macro is also responsible for).
(define example-trace (gen [x y] (+ x (* 3 y))))
(pprint example-trace)
One of the most basic things we can do with a trace is get its root value, using trace-get:
(trace-get example-trace)
Not all traces have values at their roots, and trace-get will throw an error if there is no value for it to return. So before you call trace-get, it is sometimes useful to call trace-has?, which checks if a trace has a root value.
; Recall that a vector is represented
; as a trace with no value at its root,
; and subtraces for each of its elements.
(trace-has? [1 2 3])
Both trace-has? and trace-get accept an optional second argument: an address. An address is either:
- a single subtrace name (a string, number, or boolean), or
- a list of names, describing a path from the root down through successive subtraces.
This version of trace-has? returns false if the given address is invalid for the trace, or if it is valid but the specified subtrace has no value at its root.
(plot-trace [5 7 9] 200 200)
; Get the value at the subtrace named 2
(trace-get [5 7 9] 2)
; There is no subtrace 4
(trace-has? [5 7 9] 4)
; There is no subtrace "rest"
(trace-has? '() "rest")
; The subtrace "rest" exists, but has no value at its root
(trace-has? '(1) "rest")
; The subtrace exists and has a value (2)
(trace-has? '(1 2) "rest")
; Returns the _value_ at the subtrace named "rest" -- 2 -- and not
; the subtrace itself -- which would be '(2).
(trace-get '(1 2) "rest")
; Gets the value in the "environment" subtrace.
(trace-get example-trace "environment")
; Gets the _value_ in the 0 subtrace of the "body" subtrace of the "generative-source" subtrace
(trace-get example-trace '("generative-source" "body" 0))
; There is no child subtrace called 4 of the '("generative-source" "body") subtrace.
(trace-has? example-trace '("generative-source" "body" 4))
trace-get-maybe for safe access

In this exercise, you will write your first Metaprob function!
To define a function, use the form
(define func-name
(gen [arg1 arg2 ...]
body))
In the box below, define (trace-get-maybe tr adr default), which should return the value of tr at adr, or default if no value exists at that address.
; Solution
(define trace-get-maybe
(gen [tr adr default]
(if (trace-has? tr adr) (trace-get tr adr) default)))
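A quick check of the solution, using the vector [5 7 9] and the (arbitrary) default :none:

```clojure
; The subtrace named 1 has root value 7...
(trace-get-maybe [5 7 9] 1 :none)
; => 7
; ...but there is no subtrace named 5.
(trace-get-maybe [5 7 9] 5 :none)
; => :none
```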
vec-ref

Write a function (vec-ref v n) that takes in a vector (a trace with subtraces labeled 0, 1, ...), and an element number n, and returns the nth element of the vector. If n is out-of-bounds for the vector, return the Clojure keyword :out-of-bounds.
As in Clojure, the body can contain multiple expressions, in which case they will be evaluated in turn and the last result will be returned.
; Solution
(define vec-ref
(gen [v n]
(trace-get-maybe v n :out-of-bounds)))
NB: trace-has? does not check whether a particular address exists in the trace, just whether there is a value at that address. It could return false for the address '("a" "b") but true for '("a" "b" "c"). Similarly, trace-get does not return an entire subtrace at an address, but just the value at that address. (trace-get example-trace "generative-source") returns the string "gen", not the entire subtrace representing the example-trace source code.
In order to check for or retrieve a subtrace at an address, use trace-has-subtrace? and trace-subtrace. Unlike trace-has? and trace-get, these two functions require a second argument, which specifies the address of the subtrace in question.
(plot-trace example-trace)
; Extract the generative-source subtrace of our example.
(plot-trace (trace-subtrace example-trace "generative-source"))
; Extraction at a longer address.
(plot-trace (trace-subtrace example-trace '("generative-source" "body" 2)) 300 200)
We can represent an $n$-dimensional array in Metaprob using vectors of vectors. For example, a matrix (2D array) would be a vector of matrix rows, each of which is itself a vector of numbers:
(define example-2d-matrix [[1 2 3]
[4 5 6]
[7 8 9]])
Write a function (n-d-vec-ref mat indices) which takes in:
- mat, an n-dimensional array represented as a vector of vectors (of vectors, of vectors, ...), and
- indices, a list of indices, one per dimension.
It should return :out-of-bounds (if the indices push beyond the array's bounds) or the value at the given address.
(Food for thought: why doesn't (trace-get v indices) suffice to implement this operation? If confused, try using plot-trace on an N-D vector.)
; Solution
(define n-d-vec-ref
(gen [v indices]
; If `indices` is non-empty:
(if (trace-has? indices)
; Then check whether the first index exists as a subtrace of v
(if (and (trace? v) (trace-has? v (trace-get indices)))
; If it does, pull it out and make a recursive call
(n-d-vec-ref (trace-get v (trace-get indices)) (trace-subtrace indices "rest"))
; If it doesn't, we're out of bounds
:out-of-bounds)
; If `indices` is empty, we're done: return v
v)))
; Testing: should output true.
(and
(= 3 (n-d-vec-ref example-2d-matrix '(0 2)))
(= 7 (n-d-vec-ref example-2d-matrix '(2 0)))
(= :out-of-bounds (n-d-vec-ref example-2d-matrix '(1 1 2)))
(= :out-of-bounds (n-d-vec-ref example-2d-matrix '(1 3)))
(= :out-of-bounds (n-d-vec-ref example-2d-matrix '(3 1))))
In a Metaprob program she is working on, Alice has written the following expression:
(trace-subtrace (trace-subtrace (trace-get flip "implementation") "generative-source") "body")
This code is correct, but during code review, it strikes Bob as inelegant. He rewrites it as follows, taking advantage of the fact that functions like trace-get accept arbitrarily long "paths" of subtrace names as addresses:
(trace-subtrace flip '("implementation" "generative-source" "body"))
He is so pleased with the elegance of his code that he commits without testing. Oops!
Why doesn't Bob's rewrite work, when Alice's original code did? Can Alice's code still be made more concise in another way?
(Note: flip is a real Metaprob value, so you can run both snippets -- and your own proposed rewrite -- to test them out.)
; Solution
; `flip` has a subtrace called "implementation", but that subtrace has no children. In particular, "generative-source" is not a child of the "implementation" subtrace.
; So the address '("implementation" "generative-source") is meaningless for `flip`.
; Instead, the value stored at the root of `flip`'s "implementation" subtrace is itself a trace.
; Alice's code can be simplified somewhat, to:
(trace-subtrace (trace-get flip "implementation") '("generative-source" "body"))
; because '("generative-source" "body") is exactly the address that calling `(trace-subtrace (trace-subtrace ... "generative-source") "body")` (as Alice did) probes.
; Run this cell to visualize
(plot-trace (trace-get flip "implementation") 2000 2000)
In the exercise above, you saw a situation in which a trace's root value was itself a trace. Write flattened-trace-has-subtrace? and flattened-trace-subtrace, which work just like trace-has-subtrace? and trace-subtrace, with the following exception: if you reach a "dead end" while traversing an address, check to see if the value stored at that dead end is itself a trace (using trace?), and if so, continue traversing as if it were a subtrace.
This will enable you to write, e.g., (flattened-trace-subtrace flip '("implementation" "generative-source" "body")), or (flattened-trace-subtrace example-2d-matrix '(1 1)).
; Solution
(define flattened-trace-has-subtrace?
(gen [tr adr]
(or
; address is empty: we are already at the subtrace
(not (trace-has? adr))
; the first piece of the address exists as a direct subtrace,
; and the recursive call succeeds
(and
(trace-has-subtrace? tr (trace-get adr))
(flattened-trace-has-subtrace? (trace-subtrace tr (trace-get adr)) (trace-subtrace adr "rest")))
; our current node stores a trace, for which (flattened-trace-has-subtrace? . adr) succeeds.
(and
(trace-has? tr)
(trace? (trace-get tr))
(flattened-trace-has-subtrace? (trace-get tr) adr)))))
(define flattened-trace-subtrace
(gen [tr adr]
(cond
(not (trace-has? adr))
tr
(trace-has-subtrace? tr (trace-get adr))
(flattened-trace-subtrace (trace-subtrace tr (trace-get adr)) (trace-subtrace adr "rest"))
(and (trace-has? tr) (trace? (trace-get tr)))
(flattened-trace-subtrace (trace-get tr) adr))))
Metaprob provides two functions, (trace-set tr [adr] val) and (trace-set-subtrace tr adr sub), which return modified versions of traces:
- (trace-set tr val) changes the value of the trace tr's root node to val.
- (trace-set tr adr val) changes the root value of the subtrace at the address adr to val.
- (trace-set-subtrace tr adr sub) changes the entire subtrace of tr at adr to sub.
These functions do not change the original trace, but rather create a copy of the trace with some value or subtrace changed. Metaprob does have mutable traces, and operators like trace-set! and trace-set-subtrace! that operate on them, but we will not have a need for them in this tutorial.
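For example (v and v2 are our own illustrative names; note that the original vector is untouched):

```clojure
(define v [5 7 9])
(define v2 (trace-set v 1 42))
(trace-get v2 1)
; => 42
(trace-get v 1)
; => 7 -- the original is unchanged
```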
Predict the results of each of the following expressions:
(clojure.core/refer-clojure :only '[println])
; 1.
(println "Problem 1")
(pprint (trace-set [5 7 9] 0))
; 2.
(println "Problem 2")
(pprint (trace-set [5 7 9] 0 0))
; 3.
(println "Problem 3")
(pprint (trace-set [5 7 9] 5 2))
; 4.
(println "Problem 4")
(pprint (trace-set-subtrace '(1 2 3 4) '("rest" "rest") '(5 6 7)))
; 5.
(println "Problem 5")
(pprint (trace-set-subtrace (gen [] 5) '("generative-source" "pattern")
(trace-subtrace (gen [x y z] (+ x y z))
'("generative-source" "pattern"))))
; 6.
(println "Problem 6")
(pprint (trace-set-subtrace [5 7 9] 2 '(5 7 9)))
; 7.
(println "Problem 7")
(pprint (trace-set [5 7 9] 2 '(5 7 9)))
There is also a trace-delete function, which clears the value in a trace's root node. Optionally, it can take an address; in that case, it clears the root value of the subtrace at that address (but does not delete the subtrace).
(plot-trace (trace-delete [5 7 9] 2) 200)
In the above example, notice that although the 2 subtrace has no value, it still exists in the trace.
(trace ...) is Metaprob's all-purpose constructor for traces. If you pass it no arguments, you get the empty trace:
(plot-trace (trace) 100)
One way to construct a bigger trace is to start out with an empty one and use trace-set and trace-set-subtrace to add values and sub-traces. This can be quite tedious, so (trace ...) also supports an argument syntax that is demonstrated in the examples below:
; Trace with a root value
(plot-trace (trace :value 10) 75)
; Trace with subtraces with root values
(plot-trace (trace "Massachusetts" "Boston", "California" "Sacramento", "New York" "New York") 300)
; Trace with root value and subtraces with root values
(plot-trace
(trace :value "John Doe", "first name" "John", "last name" "Doe", "address" "500 Penn Ave.", "age" 94)
400)
; Trace with arbitrary named subtrace, using (** ...)
(plot-trace
(trace "a" (** (trace "b" "c")) "b" 5)
300 100)
; Trace with multiple named subtraces and a value
(plot-trace
(trace :value "phrases",
"green" (** (trace :value "first word",
"eggs" (** (trace "and" (** (trace "ham" "!")))),
"peace" ".",
"ideas" (** (trace "sleep" (** (trace "furiously" "?"))))))))
As a way to get used to these trace operations, in this exercise, we'll implement binary search trees in Metaprob using traces. For our purposes, a binary search tree is either:
- the empty trace, or
- a trace with a number at its root and optional subtraces named "left" and "right", each itself a binary search tree, where every number in the "left" subtree is less than or equal to the root value, and every number in the "right" subtree is greater than it.
For this exercise, write:
- An example tree, example-tree, constructed using the (trace ...) function, with the numbers 1 through 10. (There are multiple valid binary search trees containing these numbers; any organization that is consistent with the above rules is fine.)
- (insert tree i), which inserts a number i into tree.
- (contains? tree i), which checks whether a tree contains a given value.

; Solution
(define example-tree
(trace :value 5
"left" (** (trace :value 3
"left" (** (trace :value 1 "right" 2))
"right" 4))
"right" (** (trace :value 7
"left" 6
"right" (** (trace :value 9 "left" 8 "right" 10))))))
(define insert
(gen [tree i]
(if (empty-trace? tree)
(trace-set tree i)
(block
(define branch (if (<= i (trace-get tree)) "left" "right"))
(trace-set-subtrace
tree
branch
(insert (if (trace-has-subtrace? tree branch)
(trace-subtrace tree branch)
(trace)) i))))))
(define contains?
(gen [tree i]
(and (not (empty-trace? tree))
(or
(= (trace-get tree) i)
(block
(define branch (if (<= i (trace-get tree)) "left" "right"))
(and (trace-has-subtrace? tree branch)
(contains? (trace-subtrace tree branch) i)))))))
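Since insert returns a new tree rather than mutating its argument, we can sanity-check it together with contains? (11 is an arbitrary value not in example-tree):

```clojure
; example-tree does not contain 11...
(contains? example-tree 11)
; => false
; ...but a tree with 11 inserted does:
(contains? (insert example-tree 11) 11)
; => true
```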
(clojure.core/every? (gen [n] (contains? example-tree (+ n 1))) (range 10))
You've now written several functions that traverse traces recursively, but you always knew the general structure of the trace you were traversing. For cases when you don't know ahead-of-time how many children a trace will have -- or what its subtraces will be called -- the (trace-keys tr) function comes in handy. It lists the names of all subtraces of a given trace.
(trace-keys example-trace)
Note that Metaprob has a version of Clojure's map function, which often comes in handy when dealing with lists of subtrace names:
; Get values at each subtrace
(map (gen [name] (trace-get example-trace name)) (trace-keys example-trace))
Metaprob also has apply, which applies a procedure to a list of arguments:
; Concatenate the names of a trace's children
(apply clojure.core/str (trace-keys example-trace))
Currently, the plot-trace function draws trace diagrams of whatever size the user specifies. In this exercise, you'll write a function smart-plot-trace that automatically decides on a diagram size, based on the trace's breadth and depth:
- (trace-depth tr), which recursively computes the maximum depth of a trace.
- (trace-breadth tr i), which gives the trace's breadth at level i. (Hint: sum the widths of the sub-traces at level i-1.)
- (max-trace-breadth tr), which gives the maximum breadth of a trace at any of its levels. (Note: the implementation we are suggesting in this exercise, which simply calls trace-breadth once for each level of the trace, is $O(n^2)$; there are certainly more efficient algorithms!)
- smart-plot-trace, which calls plot-trace, passing in reasonable values for diagram width and height (based on the depth and breadth of the trace).

; Bring in some useful Clojure functions/macros
(clojure.core/refer-clojure :only '[empty? max])
; Solution
(define trace-depth
(gen [tr]
(if (empty? (trace-keys tr))
1
(apply max
(map (gen [k]
(+ 1 (trace-depth (trace-subtrace tr k)))) (trace-keys tr))))))
(define trace-breadth
(gen [tr i]
(if (= i 0)
1
(apply
+
(map (gen [k]
(trace-breadth (trace-subtrace tr k) (- i 1))) (trace-keys tr))))))
(define max-trace-breadth
(gen [tr]
(apply max (map (gen [i] (trace-breadth tr i)) (range (trace-depth tr))))))
(define smart-plot-trace
(gen [tr]
(plot-trace tr (* (trace-depth tr) 100)
(* (max-trace-breadth tr) 80))))
(smart-plot-trace (gen [x y] (+ x (* 3 y))))
(smart-plot-trace (trace-get flip '("implementation")))
(smart-plot-trace
(nth (infer
:procedure comp/infer-apply
:inputs [(gen [] (flip 0.5)) [] (trace) (trace) true]) 1))
Write a function (delete-subtrace tr adr) that completely deletes the addressed subtrace from a trace.
Write a function (all-addresses-with-values tr) which returns a list of addresses at which tr has a value. (The built-in function addresses-of can already do this.)
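Since the exercise notes that the built-in addresses-of already computes this, one valid (if anticlimactic) solution is to delegate to it; a sketch:

```clojure
; Solution sketch: delegate to the built-in addresses-of.
(define all-addresses-with-values
  (gen [tr] (addresses-of tr)))
```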
; Set values at several addresses at once,
; given alternating address/value arguments.
(define trace-set-values-at
(gen [t & pairs]
(if (empty? pairs)
t
(block
(define addr (first pairs))
(define value (first (rest pairs)))
(define others (rest (rest pairs)))
(apply trace-set-values-at (pair (trace-set t addr value) others))))))
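A quick check of trace-set-values-at (point, "x", and "y" are our own illustrative names):

```clojure
(define point
  (trace-set-values-at (trace) "x" 3 "y" 4))
(trace-get point "x")
; => 3
(trace-get point "y")
; => 4
```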
So far, every procedure we've written has been deterministic: the outputs are completely determined by the inputs, and are the same every time we run the code. Now, we turn our attention to non-deterministic procedures, with which we can do the sort of probabilistic modeling that Metaprob was made for.
There are many interesting questions that can be phrased using the language of probability: how much of the pay gap between men and women is explained by differences in height? How likely is a coin to be fair, given that it just came up heads 100 times in a row? Who wrote a disputed document?
Before we can even begin to answer these questions, though, we need a model: a (sometimes simplified) mathematical description of how we imagine the real-world processes behind these questions (the job market, human biology, authorship) actually work.
One of the key insights in probabilistic programming is that writing a program to simulate a process is a good way to specify our probabilistic models.
A few examples will make this clearer.
Suppose you are interviewing for a position at a financial firm, and your interviewer produces a coin. He flips the coin 100 times, and you observe 100 heads. What is the probability, he asks you, that the next flip, too, will come up heads?
Your answer will depend on your model of the random process you're observing.
Here is one possible model:
(define coin-model-1
(gen []
(replicate 100 (gen [] (flip 0.5)))))
Here, we have used replicate, which takes a number (100, in this case) and a 0-argument procedure, and runs the procedure that many times, returning a list of the results from each run. In this model, we hypothesize that the underlying process at work is as follows: a fair coin is flipped 100 times. We are 100% sure that the coin is fair: hence the constant 0.5.
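To see replicate on its own, here is a trivial, deterministic call (the thunk just returns a constant):

```clojure
(replicate 3 (gen [] 7))
; => (7 7 7)
```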
Another model might not be so trusting of the interviewer:
(define coin-model-2
(gen []
(define p (uniform 0 1))
(replicate 100 (gen [] (flip p)))))
Here, we imagine that the interviewer randomly decided on the coin's weight (a number between 0 and 1) before we entered the room, and procured a biased coin to flip for us.
Perhaps we don't really believe that's possible, though: how does one get a "weighted coin" anyway? Maybe we really believe only three possibilities going in: the coin is fair, a "heads" is painted on both sides, or a "tails" is painted on both sides. Let's try to model that scenario:
(define coin-model-3
(gen []
(define p (uniform-sample [0 0.5 1]))
(replicate 100 (gen [] (flip p)))))
The "gambler's fallacy" is (according to Wikipedia) "the mistaken belief that, if something happens more frequently than normal during a given period, it will happen less frequently in the future." Create a model coin-model-4 that posits the existence of a "God-like" figure who adjusts the weight of the coin after each flip to help ensure, at every step, that the whole sequence of flips will be as evenly split between Heads and Tails as possible.
; Solution
(define reduce
(gen [f start l]
(if (empty-trace? l)
start
(reduce f (f start (first l)) (rest l)))))
(define coin-model-4
(gen []
(reduce
(gen [l n]
(define current-ratio (/ (length (filter (gen [x] x) l)) n))
(pair (flip (- 1 current-ratio)) l))
(list (flip 0.5))
(map (gen [x] (+ x 1)) (range 99)))))
Research shows that men are consistently paid more than women for the same work (OECD), and also that no matter their gender, taller people are paid more (Case and Paxson, 2006). Of course, gender and height are correlated: men tend to be taller than women. That research in hand, we may want to answer questions like:
We can encode the research on the subject into a model. Below, the numbers for average height and for the relationship between height and pay are from Case and Paxson. The median salary data for men and women are from the Bureau of Labor Statistics. That said, this is still a highly simplified model.
(define wage-gap-model
(gen []
(define sex (uniform-sample [:m :f]))
(define height-mean (if (= sex :m) 70 64))
(define height (gaussian height-mean 4))
(define base-salary-mean (if (= sex :m) 44600 36920))
(define height-adjusted-salary-mean (* base-salary-mean (exp (* (- height height-mean) (if (= sex :m) 0.024 0.019)))))
(define salary (gaussian height-adjusted-salary-mean 2000))
salary))
(wage-gap-model)
Many queries can be framed as questions about averages. For example, we might wonder, under our wage gap model, what the average salary is (regardless of gender or height). Or, under our various coin flip models, we might ask what the average number of heads is. More interestingly, we can also ask about averages under some condition: given that the first ten flips were heads, what is the average total number of heads?
In these contexts, the word "average" is not quite accurate. We can take averages of actual lists of numbers, but our models are not lists of numbers: they are stochastic processes that produce numbers (and other data).
What we are really asking about is something called an "expectation." If we have a generative probabilistic model p, and a function f that takes in samples from p and outputs numbers, the expectation of f with respect to the distribution p is the weighted average of the values taken by (f x) as x is drawn from the distribution p—we weight more probable values of x more heavily, and less probable values more lightly. The expectation can be obtained (with probability 1) by taking the limit of the average, as our number of samples goes to infinity, of samples of (f (p)).
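In symbols: if $x_1, \ldots, x_N$ are independent samples drawn from $p$, then with probability 1,

$$\mathbb{E}_{x \sim p}[f(x)] = \lim_{N \to \infty} \frac{1}{N} \sum_{i=1}^{N} f(x_i).$$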
For example, here is a function count-heads, which will play the role of f. It takes a sample from one of our coin flip models, and returns the number of heads observed.
(define count-heads
(gen [flips]
(length (filter (gen [x] x) flips))))
(count-heads (coin-model-1))
(count-heads (coin-model-2))
(count-heads (coin-model-3))
Above, we've taken one sample of (f (p)) for each coin flip model p. What we'd like to do is take the expectation -- the expected average over infinitely many samples. There are ways to calculate this number exactly; indeed, you probably have a good intuition for what it should be in the three cases above. But in more complex models, that can be harder, so we will focus here on estimation.
The simplest technique for estimating an expectation is taking a lot of samples and averaging them. Let's write that up:
(define avg (gen [l] (clojure.core/float (/ (apply + l) (length l)))))
(define mc-expectation
(gen [p f n]
(avg (map f (replicate n p)))))
(mc-expectation coin-model-1 count-heads 1000)
(mc-expectation coin-model-2 count-heads 1000)
(mc-expectation coin-model-3 count-heads 1000)
We can also use this simple technique to ask about probabilities of events. For example, to estimate the probability under each model of attaining at least 75 heads, we can estimate the expectation of a function that returns one when the condition holds, and zero otherwise. We call this an indicator function.
Write a function indicator which takes in a condition (a predicate that accepts or rejects a given value) and returns an indicator function for that condition (a function that returns 0 or 1 based on whether its input satisfies the condition). Use this to create an indicator for the condition in which there are more than 75 heads in an experiment, and estimate the probability of this event under coin models 1, 2, and 3.
; Solution:
(define indicator
(gen [condition] (gen [x] (if (condition x) 1 0))))
(mc-expectation coin-model-1 (indicator (gen [l] (>= (count-heads l) 75))) 2000)
(mc-expectation coin-model-2 (indicator (gen [l] (>= (count-heads l) 75))) 2000)
(mc-expectation coin-model-3 (indicator (gen [l] (>= (count-heads l) 75))) 2000)
Under what conditions is mc-expectation likely to give a misleading answer? (Hint: under coin-model-1, is the probability of 75 heads really 0? What is the expected amount won in a lottery?)
Let's try to apply our mc-expectation procedure to our wage-gap model, first to answer the question, "Under our model, what is the average person's height?"
We immediately hit a snag: last time, the number of heads was a straightforward transformation of coin-model-i's return value. Here, wage-gap-model returns only a salary, from which we cannot recover the height of the person whose salary it is.
We could alter the model code to return more values—say, [sex height salary] as a vector—but there's another way.
So far, we've seen only one way of interacting with a Metaprob procedure: running it. But Metaprob provides a lot more flexibility than that, via its powerful infer operator.
One of the most basic things we can ask infer to do for us is trace the random choices made by a Metaprob procedure:
(infer
:procedure wage-gap-model)
infer runs the given procedure and returns three values: the procedure's return value (in this case, a salary), an execution trace that records the random choices made during the procedure's execution, and a "score" (which is zero in this case — we'll ignore it for now).
Let's look closer at the second return value—the execution trace.
(define [salary tr _]
(infer
:procedure wage-gap-model))
(smart-plot-trace tr)
wage-gap-model makes three random choices during its execution, and as such, its execution trace records three values. The addresses at which they are stored reflect where in wage-gap-model's source code they were made. For example, the first recorded value is the choice :f, at address '(0 "sex" "uniform-sample"). This indicates that the choice in question was made in the first expression (index 0) of the procedure body, while defining the variable sex, using the procedure uniform-sample.
Let's rewrite our mc-expectation procedure to run f on the execution trace of p, rather than its return value:
(define mc-expectation-v2
(gen [p f n]
(avg
(map f (replicate n
(gen []
(define [_ tr _] (infer :procedure p)) tr))))))
We can now answer our question about average height:
(define person-height
(gen [t] (trace-get t '(2 "height" "gaussian"))))
(mc-expectation-v2 wage-gap-model person-height 1000)
This is quite close to the true answer of 67.
It can be cumbersome to type the addresses manually. Let's give them short names:
(addresses-of tr)
(define [sex-adr height-adr salary-adr] (addresses-of tr))
Estimating an expectation can be useful, but one of the advantages of Bayesian inference is that we can also quantify our uncertainty. By taking a lot of samples and plotting them in a histogram, we can get a sense of what an entire distribution looks like.
(histogram "Model 1"
(map count-heads (replicate 1000 coin-model-1)) [0 100] 1)
(histogram "Model 2" (map count-heads (replicate 100 coin-model-2)) [0 100] 1)
(histogram "Model 3" (map count-heads (replicate 400 coin-model-3)) [0 100] 1)
Although we can sometimes learn interesting things by taking expectations with respect to our models, the most exciting questions usually require conditioning on some piece of data. We might wonder what a person's expected salary is given that she is a woman, or with what probability a person is a woman given that they make $42,000 per year.
Let's look at the simpler of those two questions first. One way to estimate an answer would be to run many simulations in which we force the model to assign to sex the value :f, then average the resulting salaries.
We can intervene on our model and force it to make a certain choice by passing an intervention trace to infer. infer runs the model as usual, but when it encounters a random choice with an address that appears in the intervention trace, instead of generating a new choice, it simply reuses the one in the intervention trace.
In the code below, notice the modified call to infer, and also the manner in which we construct an intervention trace using the address of the choice we wish to control:
; First, create a version of mc-expectation that accepts an intervention trace
(define mc-expectation-v3
(gen [p intervene f n]
(avg
(map f (replicate n
(gen []
(define [_ tr _] (infer
:procedure p
:intervention-trace intervene)) ; NEW
tr))))))
; Create our intervention trace
(define ensure-female
(trace-set (trace) sex-adr :f))
; Create our accessor (the f to take the expectation of)
(define get-salary
(gen [t] (trace-get t salary-adr)))
; Run the query
(mc-expectation-v3 wage-gap-model ensure-female get-salary 1000)
This is quite a limited technique, however. If we wanted to ask, for instance, about someone's expected height, given that their salary is $42,000, an intervention would not do the trick. Do you see why? Here's what that (wrong) code would look like:
; Create our intervention trace
(define ensure-42k
(trace-set (trace) salary-adr 42000))
; Create our accessor (the f to take the expectation of)
(define get-height
(gen [t] (trace-get t height-adr)))
; Run the query
(mc-expectation-v3 wage-gap-model ensure-42k get-height 1000)
; Run the query with NO intervention — just calculates average height, regardless of salary
(mc-expectation-v3 wage-gap-model (trace) get-height 1000)
Why are these two numbers (almost) the same? Can you characterize the situations when interventions do work as a method of conditioning? Write code for one conditional query that can be answered with interventions, and one that can't.
In order to answer more sophisticated conditional queries, we need to turn to another technique: rejection sampling. In rejection sampling, we run our model over and over until we get a sample that satisfies our condition. To estimate a conditional expectation using rejection sampling, we can do this $n$ times, and average the values f takes on the $n$ samples that satisfy our condition.
(define rejection-sample
(gen [p condition f]
(define [_ t _] (infer :procedure p))
(if (condition t) (f t) (rejection-sample p condition f))))
(define rejection-expectation
(gen [p condition f n]
; Average f over n independent rejection samples, each
; obtained by re-running p until the condition holds.
(avg (replicate n (gen [] (rejection-sample p condition f))))))
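Here is the same rejection-sampling pattern sketched in plain Python, against a toy height/weight model invented for illustration (it is not the tutorial's wage-gap-model): we repeatedly run the model, keep only runs satisfying the condition, and average.

```python
import random

def rejection_sample(model, condition, f):
    # Re-run the model until the condition holds, then apply f
    while True:
        t = model()
        if condition(t):
            return f(t)

def rejection_expectation(model, condition, f, n):
    # Average f over n independent rejection samples
    return sum(rejection_sample(model, condition, f) for _ in range(n)) / n

# Toy model (invented for illustration): weight depends linearly
# on height, plus Gaussian noise.
def person():
    height = random.gauss(70, 3)
    weight = random.gauss(150 + 8 * (height - 70), 15)
    return {"height": height, "weight": weight}

random.seed(1)
# Expected height among people whose weight is near 175:
est = rejection_expectation(
    person,
    lambda t: 170 < t["weight"] < 180,
    lambda t: t["height"],
    500)
```

Because heavier people tend to be taller under this toy model, the estimate comes out above the unconditional mean height of 70.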
Now, suppose we want to know the expected height of someone making $\$56,000$ — a very high salary under our model. If we expressed our condition as "a salary of exactly $\$56,000$", it would take a very long time to get even one sample. So instead, we ask about a small interval around $\$56,000$. It is still quite slow:
(rejection-expectation
wage-gap-model
(gen [t] (< 55500 (trace-get t salary-adr) 56500))
(gen [t] (trace-get t height-adr))
100)
If we make the same query, but for an interval around $\$46,000$, it runs much faster:
(rejection-expectation
wage-gap-model
(gen [t] (< 45500 (trace-get t salary-adr) 46500))
(gen [t] (trace-get t height-adr))
100)
This is because our new condition is much more probable. As such, it takes fewer "tries" to produce a sample that satisfies it.
Write two rejection sampling queries to estimate (a) the probability that someone making between $\$45,500$ and $\$46,500$ is female, and (b) the average height of a woman whose salary is within that range. Before you run each one, try to estimate (roughly, qualitatively) how fast or slow it will be.
; Solution (a)
(rejection-expectation
wage-gap-model
(gen [t] (< 45500 (trace-get t salary-adr) 46500))
(indicator (gen [t] (= (trace-get t sex-adr) :f)))
100)
; Solution (b)
(rejection-expectation
wage-gap-model
(gen [t] (and (= (trace-get t sex-adr) :f) (< 45500 (trace-get t salary-adr) 46500)))
(gen [t] (trace-get t height-adr))
100)
With rejection sampling, we can begin to see the first applications of probabilistic modeling to data analysis. The idea is that we condition our model's execution on the actual data we've observed, then ask questions about the latent variables — the pieces we may not have observed directly.
As one example, consider the following model:
(define hybrid-coin-model
(gen []
(define which-model (uniform-sample [coin-model-1 coin-model-2 coin-model-3]))
(which-model)))
Given that we saw 100 heads, we can now ask about likely values of which-model. This allows us to pose questions like: if I saw 100 heads, which of my three possible explanations was most likely? And what are the chances the coin comes up heads next time?
(define count-heads-in-trace
(gen [t]
(apply + (map (gen [adr] (if (and (boolean? (trace-get t adr)) (trace-get t adr)) 1 0))
(addresses-of t)))))
; Probability of Coin Model 3 if I saw 100 heads:
(rejection-expectation
hybrid-coin-model
(gen [t] (= (count-heads-in-trace t) 100))
(indicator (gen [t] (= coin-model-3 (trace-get t '(0 "which-model" "uniform-sample")))))
100)
; Estimated probability of next coin flip being heads if I saw 100 heads
(rejection-expectation
hybrid-coin-model
(gen [t] (= (count-heads-in-trace t) 100))
(gen [t]
(define actual-model (trace-get t '(0 "which-model" "uniform-sample")))
(cond
(= actual-model coin-model-1) 0.5
(= actual-model coin-model-2) (trace-get t '(1 "which-model" 0 "p" "uniform"))
(= actual-model coin-model-3) (trace-get t '(1 "which-model" 0 "p" "uniform-sample"))))
100)
In most runs, this comes out very close to 1, but not exactly 1, indicating that some uncertainty remains.
So far, we have been making "point estimates" to answer our queries, summarizing our answers as a single number. For example, suppose I tell you I have flipped a coin 100 times, and that it came up heads more than sixty times. I ask you to guess exactly how many times it came up heads. Using the hybrid-coin-model, you could compute an expectation, which is one way to generate a single-number answer:
(rejection-expectation
hybrid-coin-model
(gen [t] (> (count-heads-in-trace t) 60))
count-heads-in-trace
100)
But one of the main advantages of the sort of probabilistic modeling we're doing is that we have access to much richer information.
To introduce some terminology, the process of "Bayesian data analysis" consists of the following steps:
Step 1. First, express our prior beliefs about a generative process by encoding them into a model. The distribution over traces induced by that model is called the prior: it encodes our beliefs prior to seeing any data. In our case, the prior distribution over "number of heads" looks like this:
(histogram "Prior" (replicate 500 (gen [] (count-heads (hybrid-coin-model)))) [0 100] 1)
As we can see by examining this plot, although we believe that any number of heads is theoretically possible, we would be much more surprised to see, say, 75 heads than 0, 50, or 100. We can also plot our prior beliefs about which of the three models is being used, and check that all three are basically equally likely:
(define model-number
(gen [t]
(define which (trace-get t '(0 "which-model" "uniform-sample")))
(cond
(= which coin-model-1) 1
(= which coin-model-2) 2
(= which coin-model-3) 3)))
(histogram "Prior on model choice" (replicate 500 (gen [] (model-number (nth (infer :procedure hybrid-coin-model) 1)))) [1 3] 1)
Step 2. Condition on observed data, retrieving the posterior distribution -- like the prior distribution, but updated to reflect our new beliefs after seeing the data. In this case, we can observe that more than 60 of the flips came up heads:
(define plot-posterior-rejection
(gen [p condition f n [min-val max-val] step]
(histogram "Posterior" (replicate n (gen [] (rejection-sample p condition f))) [min-val max-val] step)))
(plot-posterior-rejection hybrid-coin-model (gen [t] (> (count-heads-in-trace t) 60)) count-heads-in-trace 100 [0 100] 1)
As we can see, our point estimate above will in most cases be quite misleading!
We can also plot the posterior on which-model:
(plot-posterior-rejection hybrid-coin-model (gen [t] (> (count-heads-in-trace t) 60)) model-number 500 [1 3] 1)
Write a query that plots the salary distribution for people over 70 inches tall:
; Solution
(plot-posterior-rejection
wage-gap-model
(gen [t] (> (trace-get t height-adr) 70))
(gen [t] (trace-get t salary-adr))
300
[30000 60000] 500)
In the last two sections, we saw two ways of generating conditional samples:
Intervening on our model to force the "random choices" of interest to take on the values we want them to. This is fast, but is only correct in select cases, where the random choices we're conditioning on don't depend on any prior random choices.
Rejection sampling. This gives exact samples from the posterior; the only problem is how slow it is.
We'd like the speed of (1) with the correctness of (2). As a step toward getting there, let's think a bit more about why exactly (1) is usually incorrect.
Suppose we are trying to estimate the expected salary for someone who is 70 inches tall (5' 10"). We can get a (slow) estimate via rejection sampling:
(rejection-expectation
wage-gap-model
(gen [t] (< 69.9 (trace-get t height-adr) 70.1))
(gen [t] (trace-get t salary-adr))
100)
If we do the same using our intervention trace method, we get a fast but inaccurate answer:
(mc-expectation-v3
wage-gap-model
(trace-set (trace) height-adr 70)
(gen [t] (trace-get t salary-adr))
100)
Why is the answer lower in the intervention trace method? Under our model, men are more likely to be 70 inches tall than women, so if we know someone is 70 inches tall, we might expect it's more likely for them to be a man, and—incorporating the wage gap—we might expect their salary to be higher. The rejection sampling algorithm takes this into account. It simply waits for 100 samples where the height was 70 inches, the majority of which will be men; the salary will have been higher for those samples, so the expected salary will be higher.
The intervention method, on the other hand, is impatient, and immediately sets height to 70 in every run, severing the tie between gender and height. That means that our collection of 100 samples will contain just as many 70-inch-tall women as 70-inch-tall men, which is improbable under our model.
What if we made a compromise? We still artificially set the height of each sample to 70, but also measure how likely the model would have been to choose 70 anyway—a number that will be higher for sex=:m samples than for sex=:f samples. (For those familiar with probability theory, we're talking about probability densities: p(height=70|sex).) At the end, we take a weighted average of the salaries, where each salary is weighted by the likelihood that the employee really would have been 70 inches (given the random choices we'd already made about them: their sex).
To implement this, infer has one final trick up its sleeve: we can provide a target trace, which works just like an intervention trace, but also causes the constrained choices to be scored. For each targeted choice, Metaprob computes the probability that it would have made the same choice under the prior, and accumulates these numbers (as a sum of log probabilities) into a score that is returned, along with the model's output and execution trace, from infer.
We can use this feature to implement the likelihood weighting algorithm:
; Note -- we exponentiate the scores, because they are returned as log probabilities
(define weighted-samples
(gen [proc target f n]
(define samples (replicate n (gen [] (infer :procedure proc :target-trace target))))
(map (gen [[o t s]]
[(f t) (exp s)]) samples)))
(define lh-weighting
(gen [proc target f n]
(define samples (weighted-samples proc target f n))
(/ (apply + (map (gen [[v s]] (* v s)) samples))
(apply + (map (gen [[_ s]] s) samples)))))
Let's apply the technique to the question we explored above, to estimate the expected salary for someone who is 70 inches tall:
(lh-weighting wage-gap-model
(trace-set (trace) height-adr 70)
(gen [t] (trace-get t salary-adr))
100)
This brings us quite close to the rejection-sampled estimate, and much faster!
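To see why the weighted average works, here is the same likelihood-weighting idea in plain Python, on a toy model invented for illustration, chosen so that the exact posterior is computable by hand: sex is chosen uniformly, height is Gaussian with a sex-dependent mean, and we "target" height = 70 by sampling sex from the prior and weighting each run by the density the model assigns to height = 70 given that sex.

```python
import math
import random

def normal_pdf(x, mu, sigma):
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))

def lh_weighting(run, f, n):
    # Weighted average: sum(w_i * f(t_i)) / sum(w_i)
    pairs = [run() for _ in range(n)]
    return sum(w * f(t) for t, w in pairs) / sum(w for _, w in pairs)

# Toy targeted model (invented for illustration): sample sex from the
# prior, force height = 70, and weight by p(height = 70 | sex).
def targeted_run():
    sex = random.choice([":m", ":f"])
    mean = 69 if sex == ":m" else 64
    return {"sex": sex, "height": 70}, normal_pdf(70, mean, 3)

random.seed(2)
p_male = lh_weighting(targeted_run, lambda t: 1.0 if t[":sex" [1:]] == ":m" else 0.0, 4000)
# exact answer: pdf(70; 69, 3) / (pdf(70; 69, 3) + pdf(70; 64, 3)), about 0.875
```

The weights are larger for male samples (70 is closer to the male mean), so the weighted estimate of P(male | height = 70) lands well above the prior probability of 0.5.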
This same idea can be used to produce (and plot) samples from the posterior. The trick is this: to generate a single sample from the (approximate) posterior, we sample some number $n$ of "particles" with likelihood weights. We then choose one of them at random, giving each one probability proportional to its score. This variant of the technique is called "sampling/importance resampling." Let's implement it:
(define sample-importance-resample
(gen
[proc target f n-particles]
; Generate particles
(define particles (weighted-samples proc target f n-particles))
; Choose one at random, with prob. proportional to importance weight
(define sum-of-weights (apply + (map (gen [[_ s]] s) particles)))
(define which-particle (categorical (map (gen [[o s]] (/ s (+ 1e-10 sum-of-weights))) particles)))
; Return only the sampled value (not the score)
(define [sampled-value _] (nth particles which-particle))
sampled-value))
(define plot-posterior-importance
(gen [p target f particles-per-sample n [min-val max-val] step]
(histogram "Posterior (Importance Resampling)" (replicate n (gen [] (sample-importance-resample p target f particles-per-sample))) [min-val max-val] step)))
One last time, we consider the problem of estimating someone's salary given that they are 73 inches tall. This time, instead of using a point estimate, we'll plot the entire posterior, which will give us a better sense of how salaries might vary within the population of 73-inch-tall people.
Let's first plot using only one particle. In this case, we are not getting the posterior at all: we are just sampling from the prior with an intervention (setting height=73).
; 1 particle -- sampling from the (intervened) prior
(plot-posterior-importance
wage-gap-model
(trace-set (trace) height-adr 73)
(gen [t] (trace-get t salary-adr))
1 1000
[30000 60000] 500)
Using 5 particles per sample, we get a better approximation to the posterior:
; 5 particles -- better approx to posterior
(plot-posterior-importance
wage-gap-model
(trace-set (trace) height-adr 73)
(gen [t] (trace-get t salary-adr))
5 500
[30000 60000] 500)
To check our work, we can revert to the rejection sampling technique to get exact samples from the posterior:
(plot-posterior-rejection
wage-gap-model
(gen [t] (< 72.9 (trace-get t height-adr) 73.1))
(gen [t] (trace-get t salary-adr))
500
[30000 60000] 300)
This looks pretty similar. The upshot is that the sampling/importance resampling method gives us an accurate picture of the posterior in a much more efficient manner.
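The resampling step itself can be sketched in plain Python, again on a toy model invented for illustration, following the weighting idea described above: each posterior sample is drawn by generating a batch of weighted particles and picking one with probability proportional to its weight.

```python
import math
import random

def normal_pdf(x, mu, sigma):
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))

# Toy targeted model (invented for illustration): sample sex from the
# prior, force height = 70, and weight by p(height = 70 | sex).
def targeted_run():
    sex = random.choice([":m", ":f"])
    mean = 69 if sex == ":m" else 64
    return {"sex": sex, "height": 70}, normal_pdf(70, mean, 3)

def resample_one(particles):
    # Pick one particle with probability proportional to its weight
    total = sum(w for _, w in particles)
    r = random.uniform(0, total)
    for t, w in particles:
        r -= w
        if r <= 0:
            return t
    return particles[-1][0]

random.seed(3)
samples = [resample_one([targeted_run() for _ in range(20)]) for _ in range(2000)]
frac_male = sum(1 for t in samples if t["sex"] == ":m") / len(samples)
# with 20 particles per resample this approaches the exact value, about 0.875
```

With a single particle per sample this would collapse to sampling from the (intervened) prior, which is exactly the behavior seen in the one-particle plot above.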
In this section, we'll review all the techniques we've learned on a new example model, and look at ways of extending our models with model-specific "custom inference procedures."
The model we're using in this section is a bivariate Gaussian. A bivariate Gaussian models pairs of variables that each follow a "bell curve," but are correlated with one another — for example, a person's height and weight.
Ignoring weight, people's heights are distributed normally, with a high probability at the mean and lower probability the further you get away from the mean. Here we plot samples from a univariate Gaussian with mean 70 and standard deviation 3; standard deviation is a measure of the "spread" of a distribution.
(histogram-with-curve
"Gaussian (height)"
(replicate 5000 (gen [] (gaussian 70 3)))
(gen [x] (exp (score-gaussian x [70 3])))
[50 90])
And weight may also be distributed normally, with different mean and standard deviation:
(histogram-with-curve "Gaussian (weight)"
(replicate 5000 (gen [] (gaussian 150 20)))
(gen [x] (exp (score-gaussian x [150 20])))
[80 220])
So one idea for how to model a person's height and weight would be to draw them individually from these distributions:
(define person-v1
(gen []
(define height (gaussian 70 3))
(define weight (gaussian 150 20))
[height weight]))
But there's something wrong about this model: the height and weight are totally independent. In real life, a tall person is likely to weigh more; in our model, this is not the case.
Enter the bivariate Gaussian: in addition to a pair of means, it also takes a covariance matrix, a symmetric 2x2 matrix $\Sigma$, where $\sigma_{11}$ and $\sigma_{22}$ are the variances (squared standard deviations) of variables 1 and 2, and $\sigma_{12} = \sigma_{21}$ is the covariance of the variables with one another. A positive covariance indicates direct correlation; a negative covariance indicates inverse correlation. Zero covariance means the variables are completely independent. (Note: in general, zero covariance does not imply independence, but in a bivariate Gaussian model, it does.)
There are a number of ways to sample from a bivariate Gaussian. Below, we model the distribution as follows: first, we generate the first variable according to its mean and standard deviation (e.g., generate a height). Then, based on how far that sample was from the mean and on the covariance, we figure out the expectation of the second variable (e.g., given the height we generated, what would we expect the weight to be?). Using the variance and covariance information, we can also calculate how much of a spread we expect to find around that expected value. Now that we have a new mean and variance, we can sample the second variable (generate a weight). Here's the model:
(define conditional-biv-gaussian-mean
(gen [x1 [[mu1 mu2] [[sigma_11 sigma_12] [sigma_21 sigma_22]]]]
(+ mu2 (* (- x1 mu1) (/ sigma_12 sigma_11)))))
(define conditional-biv-gaussian-variance
(gen [x1 [[mu1 mu2] [[sigma_11 sigma_12] [sigma_21 sigma_22]]]]
(/ (- (* sigma_11 sigma_22) (* sigma_12 sigma_12)) sigma_11)))
(define conditional-biv-gaussian-density
(gen [x1 params]
(gen [x]
(exp (score-gaussian x [(conditional-biv-gaussian-mean x1 params)
(sqrt (conditional-biv-gaussian-variance x1 params))])))))
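These conditional-mean and conditional-variance expressions are the standard closed-form results for a bivariate Gaussian. As a quick sanity check outside of Metaprob, here they are in plain Python, evaluated at the height/weight parameters used later in this section (mu = [70, 150], Sigma = [[9, 40], [40, 400]]):

```python
def conditional_mean(x1, mu1, mu2, s11, s12):
    # E[x2 | x1] = mu2 + (x1 - mu1) * sigma12 / sigma11
    return mu2 + (x1 - mu1) * s12 / s11

def conditional_variance(s11, s12, s22):
    # Var[x2 | x1] = (sigma11 * sigma22 - sigma12^2) / sigma11
    return (s11 * s22 - s12 * s12) / s11

# Expected weight of a 73-inch-tall person, and the spread around it:
m = conditional_mean(73, 70, 150, 9, 40)  # 150 + 3 * 40/9, about 163.33
v = conditional_variance(9, 40, 400)      # (3600 - 1600) / 9, about 222.22
```

Note that the conditional variance does not depend on x1: knowing someone's height shifts our expected weight, but the residual spread is the same for every height.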
(define biv-gauss
(gen [[mu1 mu2] [[sigma_11 sigma_12] [sigma_21 sigma_22]]]
; draw the first variable normally
(define x1 (gaussian mu1 (sqrt sigma_11)))
; calculate the new mean and variance of the second variable
(define x2_mean
(conditional-biv-gaussian-mean
x1
[[mu1 mu2] [[sigma_11 sigma_12] [sigma_21 sigma_22]]]))
(define x2_var
(conditional-biv-gaussian-variance
x1
[[mu1 mu2] [[sigma_11 sigma_12] [sigma_21 sigma_22]]]))
; draw the second variable
(define x2 (gaussian x2_mean (sqrt x2_var)))
; return both
[x1 x2]))
(define person-v2
(gen []
; 9 = 3^2, 400 = 20^2; we pass variance instead of standard deviation here
; setting covariance to 0 (instead of 40) would exactly recover the person-v1 model above.
(biv-gauss [70 150] [[9 40] [40 400]])))
(define [height-adr weight-adr] (addresses-of (nth (infer :procedure person-v2 :inputs []) 1)))
Let's create a scatter plot to visualize the sorts of samples this generates:
(custom-scatter-plot "Bivariate Gaussian"
(replicate 1000 person-v2)
"cross" "white" "blue" [[60 80] [90 210]])
The "shape" of this data is elliptical: there is a rotated ellipse in which the points are concentrated, with density decreasing as we move away from the center of the ellipse along either of its axes. We can see this better by plotting the actual probability density; to do so, we write an assessor function that exactly computes the probability density of a specific point. Graphing that function in a contour plot can help us better understand how the idea of a "bell curve" generalizes to two dimensions:
(define biv-gaussian-density
(gen [[mu1 mu2] [[s11 s12] [s21 s22]]]
(gen [x y]
(define x-from-mu (- x mu1))
(define y-from-mu (- y mu2))
(/
(exp
(/
(- (+
(/ (* x-from-mu x-from-mu) s11)
(/ (* y-from-mu y-from-mu) s22))
(/ (* 2 s12 x-from-mu y-from-mu) (* s11 s22)))
(* -2 (/ (- (* s11 s22) (* s12 s12)) (* s11 s22)))))
(* 2 3.1415926 (sqrt (- (* s11 s22) (* s12 s12))))))))
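As a sanity check on this formula, here is a direct Python transcription (with the exponent rewritten in an algebraically equal form): a bivariate Gaussian density should peak at the mean with value $1/(2\pi\sqrt{\det\Sigma})$, and be symmetric about the mean.

```python
import math

def biv_gaussian_density(mu1, mu2, s11, s12, s22):
    det = s11 * s22 - s12 * s12
    def density(x, y):
        dx, dy = x - mu1, y - mu2
        # Quadratic form of the inverse covariance matrix
        # (algebraically equal to the exponent in the Metaprob version)
        q = (s22 * dx * dx - 2 * s12 * dx * dy + s11 * dy * dy) / det
        return math.exp(-q / 2) / (2 * math.pi * math.sqrt(det))
    return density

d = biv_gaussian_density(70, 150, 9, 40, 400)
peak = d(70, 150)  # should equal 1 / (2*pi*sqrt(det)), with det = 2000
```

The density at any point (mu1 + dx, mu2 + dy) equals the density at (mu1 - dx, mu2 - dy), which is why the contour plot below is an ellipse centered on the mean.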
(scatter-with-contours
"Bivariate Gaussian"
(replicate 100 person-v2)
(biv-gaussian-density [70 150] [[9 40] [40 400]])
[[60 80] [90 210]])
In the plot, blue regions correspond to areas of low probability, and red regions to areas of high probability. Equiprobability contours are ellipses. The wider or taller the ellipses, the more uncertainty there is. The "off-center" rotation you see tells us that variables $x$ and $y$ are not independent: if we know one, it changes our guess about the other.
To get a feel for what the parameters of the bivariate Gaussian mean, let's look at how the contour plots change as we change the parameters:
; Changing the means moves the ellipse
(scatter-with-contours
"Bivariate Gaussian"
[]
(biv-gaussian-density [65 120] [[9 40] [40 400]])
[[60 80] [90 210]])
; Increasing sigma_11 increases the spread in the x direction.
(scatter-with-contours
"Bivariate Gaussian"
[]
(biv-gaussian-density [70 150] [[18 40] [40 400]])
[[60 80] [90 210]])
; Increasing sigma_22 increases the spread in the y direction.
(scatter-with-contours
"Bivariate Gaussian"
[]
(biv-gaussian-density [70 150] [[9 40] [40 800]])
[[60 80] [90 210]])
; Setting sigma_12 to 0 removes the correlation between the variables;
; the ellipse's axes now align with the coordinate axes.
(scatter-with-contours
"Bivariate Gaussian"
[]
(biv-gaussian-density [70 150] [[9 0] [0 400]])
[[60 80] [90 210]])
Note in this last plot that no matter what $x$ is, we induce the same distribution on $y$ (and vice versa).
Now, going back to our model of height and weight, suppose we know someone's weight and want to guess their height. We could use rejection sampling, but it's slow:
; Height samples for someone whose weight is 175
(define conditional-samples
(replicate 100
(gen []
(rejection-sample
person-v2
(gen [t] (< 174.5 (trace-get t weight-adr) 175.5))
(gen [t] [(trace-get t height-adr) (trace-get t weight-adr)])))))
(histogram-with-curve
"Exact posterior and rejection samples of height"
(map (gen [[h w]] h) conditional-samples)
(conditional-biv-gaussian-density 175 [[150 70] [[400 40] [40 9]]])
[60 80])
In this histogram, we've also plotted the correct probability density for the posterior distribution. Generally, it is not tractable to compute exact posteriors—otherwise approximate inference techniques like the ones we've been exploring would be unnecessary. But for this simple case, of a bivariate Gaussian, it is possible, and we use the analytically computed curve to evaluate how well our sampling technique approximates the posterior distribution.
We can try to fix the speed problem with sampling/importance resampling. If we use only one particle, we are essentially disregarding the condition and sampling from the prior:
; Height samples for someone whose weight is 175
(define sir-samples-1-particle
(replicate 100
(gen []
(sample-importance-resample
person-v2
(trace-set (trace) weight-adr 175)
(gen [t] [(trace-get t height-adr) (trace-get t weight-adr)])
1))))
(histogram-with-curve
"Exact posterior and prior samples of height"
(map (gen [[h w]] h) sir-samples-1-particle)
(conditional-biv-gaussian-density 175 [[150 70] [[400 40] [40 9]]])
[60 80])
We can increase the number of particles and shift the distribution toward the true posterior, at a cost of spending more time per sample. See if you can play with the number of particles (below, 5) to achieve a good trade-off.
; Height samples for someone whose weight is 175
(define sir-samples-multiple-particles
(replicate 100
(gen []
(sample-importance-resample
person-v2
(trace-set (trace) weight-adr 175)
(gen [t] [(trace-get t height-adr) (trace-get t weight-adr)])
5))))
(histogram-with-curve
"Exact posterior and approximate importance samples of height"
(map (gen [[h w]] h) sir-samples-multiple-particles)
(conditional-biv-gaussian-density 175 [[150 70] [[400 40] [40 9]]])
[60 80])
In the case of the bivariate Gaussian, we happen to know exactly the mathematical form of the posterior. That is, for our person model, we can calculate the exact distribution that someone's height is drawn from when we know their weight. It is silly to use approximate algorithms when an exact one exists.
Metaprob is designed to be easily extensible, and to allow practitioners to encode any special knowledge like this directly into their models. It supports a class of user-defined procedures called "custom inference procedures," which specify not only a generative model, but also a custom "implementation" that overrides Metaprob's default behavior in the presence of target or intervention traces. In other words, we can specify custom behavior for when (infer ...) is called on our model.
; We use `inf` to define a custom inference procedure. It takes in:
; -- a name
; -- a model
; -- an "implementation"
(define custom-biv-gauss
(inf "biv-gauss" ; name
biv-gauss ; model
; implementation:
; accepts four arguments.
; 1. inputs to the model (biv gaussian params)
; 2. an intervention trace
; 3. a target trace
; 4. a boolean `out?` that specifies whether to return an output
; trace.
; returns three things:
; 1. the output of the model on these inputs
; 2. if out? is true, an output trace consistent with intervene and
; target.
; 3. a score, the log of a number proportional to the _ratio_ between the true posterior density
; and the density (at the sampled trace) of the distribution from which our implementation samples.
; Because we are sampling from the exact posterior here, the ratio is 1, and so
; we return log(1) == 0.
(gen [[[mu1 mu2] [[s11 s12] [s21 s22]]] intervene target out?]
; read x1 from intervention or target trace, or, if it's not specified
; but x2 is in the target trace, sample it from the exact posterior.
; if neither x1 nor x2 are constrained, sample a fresh x1 just like
; the biv-gauss model does.
(define x1
(cond
(trace-has? intervene '(0 "x1" "gaussian")) (trace-get intervene '(0 "x1" "gaussian"))
(trace-has? target '(0 "x1" "gaussian")) (trace-get target '(0 "x1" "gaussian"))
(trace-has? target '(3 "x2" "gaussian"))
(gaussian (conditional-biv-gaussian-mean (trace-get target '(3 "x2" "gaussian"))
[[mu2 mu1] [[s22 s12] [s21 s11]]])
(sqrt (conditional-biv-gaussian-variance (trace-get target '(3 "x2" "gaussian"))
[[mu2 mu1] [[s22 s12] [s21 s11]]])))
true (gaussian mu1 (sqrt s11))))
; Either use the provided value for x2, or sample conditionally based on x1.
(define x2
(cond
(trace-has? intervene '(3 "x2" "gaussian")) (trace-get intervene '(3 "x2" "gaussian"))
(trace-has? target '(3 "x2" "gaussian")) (trace-get target '(3 "x2" "gaussian"))
true
(gaussian (conditional-biv-gaussian-mean x1
[[mu1 mu2] [[s11 s12] [s21 s22]]])
(sqrt (conditional-biv-gaussian-variance x1
[[mu1 mu2] [[s11 s12] [s21 s22]]])))))
; Return values
[[x1 x2]
(if out?
(trace-set (trace-set (trace) '(0 "x1" "gaussian") x1) '(3 "x2" "gaussian") x2)
(trace))
0])))
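The heart of the implementation above is the exact conditional sampler: to draw x1 given an observed x2, we reuse the conditional-mean and conditional-variance formulas with the roles of the two variables swapped. Sketched in plain Python for the height/weight parameters:

```python
import math
import random

def sample_height_given_weight(weight, mu=(70.0, 150.0),
                               cov=((9.0, 40.0), (40.0, 400.0))):
    # Exact posterior: height | weight is Gaussian with
    #   mean = mu1 + (weight - mu2) * s12 / s22
    #   var  = (s11 * s22 - s12^2) / s22
    # (the usual conditional formulas, with the variables' roles swapped)
    (mu1, mu2), ((s11, s12), (_, s22)) = mu, cov
    mean = mu1 + (weight - mu2) * s12 / s22
    var = (s11 * s22 - s12 * s12) / s22
    return random.gauss(mean, math.sqrt(var))

random.seed(4)
hs = [sample_height_given_weight(175) for _ in range(5000)]
mean_h = sum(hs) / len(hs)  # analytic conditional mean: 70 + 25 * 40/400 = 72.5
```

Because these are exact posterior samples, no weighting or resampling is needed, which is why a single "particle" suffices below.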
Using this, sampling/importance resampling will "just work," no matter how many particles we use.
; rewrite person-v2 using custom-biv-gauss instead of biv-gauss
(define person-v3
(gen []
(custom-biv-gauss [70 150] [[9 40] [40 400]])))
(define [height-adr weight-adr] (addresses-of (nth (infer :procedure person-v3) 1)))
; Height samples for someone whose weight is 175
(define sir-samples-1-particle
(replicate 100
(gen []
(sample-importance-resample
person-v3
(trace-set (trace) weight-adr 175)
(gen [t] [(trace-get t height-adr) (trace-get t weight-adr)])
1))))
(histogram-with-curve
"Exact posterior with fast exact samples"
(map (gen [[h w]] h) sir-samples-1-particle)
(conditional-biv-gaussian-density 175 [[150 70] [[400 40] [40 9]]])
[60 80])
We can use the bivariate Gaussian to help implement a new model, of the following strange situation. Suppose we go to the circus and see an act in which two brothers pretend to be one very tall person, by standing on top of one another and wearing a large coat. From our seats, we can estimate (perhaps with some error) the total height of both brothers, but we would like to infer their individual heights. Here is the model:
(define circus-brothers
(gen []
(define [h1 h2]
(custom-biv-gauss [70 70] [[9 5] [5 9]]))
(define total-height
(gaussian (+ h1 h2) 3))
total-height))
(define [h1-adr h2-adr total-adr] (addresses-of (nth (infer :procedure circus-brothers) 1)))
(scatter-with-contours
"Prior distribution of brothers' heights"
[]
(biv-gaussian-density [70 70] [[9 5] [5 9]])
[[60 80] [60 80]])
(define observed-height 155)
(scatter-with-contours
"Posterior distribution of brothers' heights"
[]
(biv-gaussian-density [(+ 17.02 (* 0.378 observed-height)) (+ 17.02 (* 0.378 observed-height))] [[3.7 -0.2966] [-0.2966 3.7]])
[[60 80] [60 80]])
This shape should make some sense: because 155 is a bit higher than what we'd ordinarily expect, the distribution has moved up and to the right. It's also straightened out: originally, knowing that one brother was taller would also make us think the second brother was taller, because brothers tend to have similar heights. But now, given that their total height is constrained to be somewhere near 155, learning that one brother is very tall would lead us to believe the other brother must be a bit shorter (and vice versa).
Let's plot some samples, using all our techniques for sampling from a conditional.
(define extract-heights
(gen [t] [(trace-get t h1-adr) (trace-get t h2-adr)]))
; From the prior
(define prior-samples
(replicate 100
(gen []
(sample-importance-resample
circus-brothers
(trace-set (trace) total-adr observed-height)
extract-heights
1))))
(scatter-with-contours
"Prior samples of brothers' heights"
prior-samples
(biv-gaussian-density [(+ 17.02 (* 0.378 observed-height)) (+ 17.02 (* 0.378 observed-height))] [[3.7 -0.2966] [-0.2966 3.7]])
[[60 80] [60 80]])
; From the approximate posterior
(define approx-posterior-samples
(replicate 100
(gen []
(sample-importance-resample
circus-brothers
(trace-set (trace) total-adr observed-height)
extract-heights
10))))
(scatter-with-contours
"Approx. posterior samples of brothers' heights"
approx-posterior-samples
(biv-gaussian-density [(+ 17.02 (* 0.378 observed-height)) (+ 17.02 (* 0.378 observed-height))] [[3.7 -0.2966] [-0.2966 3.7]])
[[60 80] [60 80]])
; From the posterior but slowly
(define posterior-samples
(replicate 100
(gen []
(rejection-sample
circus-brothers
(gen [t]
(< -0.5
(- observed-height (trace-get t total-adr))
0.5))
extract-heights))))
(scatter-with-contours
"Posterior samples of brothers' heights"
posterior-samples
(biv-gaussian-density [(+ 17.02 (* 0.378 observed-height)) (+ 17.02 (* 0.378 observed-height))] [[3.7 -0.2966] [-0.2966 3.7]])
[[60 80] [60 80]])
One technique we can use to get more accurate samples than sampling/importance resampling, in less time than rejection sampling, is "Markov Chain Monte Carlo" (MCMC). These methods involve starting with a single sample and iteratively changing it, in such a way that eventually, each successive iteration yields a sample from the posterior. There are many ways to do this, but below, we use an algorithm called Single-Site Metropolis-Hastings, which (in our case) iteratively proposes a change to one brother's height at a time, then randomly decides to either accept or reject the proposal. (The probability with which a proposal is accepted depends on a couple of factors, including how likely our constraint is given the proposal.)
(define initial-sample
((infer :procedure circus-brothers
:target-trace (trace-set (trace) total-adr 155)) 1))
(define mh-traj
(clojure.core/reduce
(gen [samples _]
(clojure.core/cons
(immutable-single-site-metropolis-hastings-step
circus-brothers []
(first samples) (list total-adr))
samples))
(list initial-sample)
(range 500)))
(scatter-with-contours
"Posterior samples of brothers' heights"
(map extract-heights mh-traj)
(biv-gaussian-density [(+ 17.02 (* 0.378 observed-height)) (+ 17.02 (* 0.378 observed-height))] [[3.7 -0.2966] [-0.2966 3.7]])
[[60 80] [60 80]])
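The Metropolis-Hastings loop itself can be sketched in plain Python. Here the unnormalized posterior is written out directly as prior times likelihood, matching the circus-brothers model: a bivariate Gaussian prior on the two heights, and a Gaussian likelihood (standard deviation 3) for the observed total of 155.

```python
import math
import random

def log_biv_gauss(x, y, mu1, mu2, s11, s12, s22):
    # Log density of a bivariate Gaussian
    det = s11 * s22 - s12 * s12
    dx, dy = x - mu1, y - mu2
    q = (s22 * dx * dx - 2 * s12 * dx * dy + s11 * dy * dy) / det
    return -q / 2 - math.log(2 * math.pi * math.sqrt(det))

def log_target(h1, h2, observed=155.0):
    # Unnormalized posterior: prior on (h1, h2) times likelihood of the total
    prior = log_biv_gauss(h1, h2, 70, 70, 9, 5, 9)
    lik = -((observed - (h1 + h2)) ** 2) / (2 * 9.0) - math.log(3 * math.sqrt(2 * math.pi))
    return prior + lik

def mh_step(state):
    # Single-site proposal: perturb one brother's height, keep the other fixed
    i = random.randrange(2)
    proposal = list(state)
    proposal[i] += random.gauss(0, 1.0)
    # Symmetric proposal, so accept with probability min(1, target ratio)
    log_ratio = log_target(*proposal) - log_target(*state)
    return tuple(proposal) if math.log(random.random()) < log_ratio else state

random.seed(5)
state, samples = (70.0, 70.0), []
for step in range(20000):
    state = mh_step(state)
    if step >= 2000:  # discard burn-in
        samples.append(state)
mean_h1 = sum(h1 for h1, _ in samples) / len(samples)
# the posterior mean of each height is 17.02 + 0.378 * 155, about 75.6
```

When a proposal is rejected, the chain repeats its current state; those repeats are exactly the stationary flashes highlighted in the animation below.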
Click Play in the animated plot below to view the trajectory. It flashes red whenever the trajectory stays stationary because of a rejected proposal.
(mh-animation "Posterior samples of brothers' heights"
(clojure.core/reverse (map extract-heights mh-traj))
[[60 80] [60 80]])
Implement a custom inference procedure for circus-brothers, using the fact that conditioned on a certain observed total height $h$, the brothers' individual heights are drawn from a bivariate Gaussian:
Plot importance samples and MH samples from this model.
(smart-plot-trace (nth (infer :procedure circus-brothers) 1))
(smart-plot-trace (nth (infer :procedure custom-biv-gauss :inputs [[0 0] [[1 0] [0 1]]]) 1))
[h1-adr h2-adr total-adr]
(define trace-subtrace-or-empty
(gen [t a]
(if (trace-has-subtrace? t a) (trace-subtrace t a) (trace))))
(define custom-circus-brothers
(inf "circus-brothers"
circus-brothers
(gen [[] intervene target out?]
; h1 h2
(define [[h1 h2] o _]
(infer :procedure custom-biv-gauss
:intervention-trace (trace-subtrace-or-empty intervene '(0 "definiens" "custom-biv-gauss"))
:target-trace (trace-subtrace-or-empty target '(0 "definiens" "custom-biv-gauss"))
:inputs
(if (trace-has? target total-adr)
[[(+ 17.02 (* 0.378 (trace-get target total-adr))) (+ 17.02 (* 0.378 (trace-get target total-adr)))]
[[3.7 -0.2966] [-0.2966 3.7]]]
[[70 70]
[[9 5] [5 9]]])))
; total height
(define total-height
(cond
(trace-has? intervene total-adr) (trace-get intervene total-adr)
(trace-has? target total-adr) (trace-get target total-adr)
true (gaussian (+ h1 h2) 3)))
[total-height (trace-set-values-at (trace)
total-adr total-height h1-adr h1 h2-adr h2)
0])))